
import torch.backends.cudnn as cudnn

import shutil
import argparse
import time

from dataset.data import *

import copy
from pruner import *




def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Two positionals are required: ``cmd`` (``train`` or ``test``) and ``arch``
    (the model architecture name). Everything else is optional.

    Returns:
        argparse.Namespace: the parsed command-line options.
    """
    # hyper-parameters are from ResNet paper
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 training')
    parser.add_argument('cmd', choices=['train', 'test'])
    parser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu'])
    parser.add_argument('--data-dir', default='/home/xxx/cifar100', type=str,
                        help='the directory to save the dataset')
    # NOTE: a positional without nargs is always required, so argparse never
    # uses this default; it is kept only to leave the interface untouched.
    parser.add_argument('arch', metavar='ARCH', default='multi_resnet50_kd',
                        help='model architecture')
    parser.add_argument('--dataset', '-d', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100', 'imagenet'],
                        help='dataset choice')
    parser.add_argument('--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')
    parser.add_argument('--epoch', default=200, type=int,
                        help='number of total epochs to run (default: 200)')
    parser.add_argument('--start-epoch', default=0, type=int,
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--batch-size', default=128, type=int,
                        help='mini-batch size (default: 128)')
    parser.add_argument('--valid-batch-size', default=64, type=int,
                        help='validation mini-batch size (default: 64)')
    parser.add_argument('--lr', default=0.1, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight-decay', default=5e-4, type=float,
                        help='weight decay (default: 5e-4)')
    parser.add_argument('--print-freq', default=100, type=int,
                        help='print frequency (default: 100)')
    parser.add_argument('--resume', default='', type=str,
                        help='path to latest checkpoint (default: None)')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pretrained model')
    parser.add_argument('--step-ratio', default=0.1, type=float,
                        help='ratio for learning rate deduction')
    parser.add_argument('--warm-up', action='store_true',
                        help='use a learning rate of 0.01 during the first epoch')
    parser.add_argument('--save-folder', default='semme_save_checkpoints/', type=str,
                        help='folder to save the checkpoints')
    parser.add_argument('--summary-folder', default='runs/', type=str,
                        help='folder to save the summary')
    parser.add_argument('--eval-every', default=1000, type=int,
                        help='evaluate model every (default: 1000) iterations')
    parser.add_argument('--debug', default=0, type=int,
                        help='debug mode: when non-zero, stop each training '
                             'epoch and each validation pass after about 200 '
                             'batches (default: 0)')
    # kd parameter
    parser.add_argument('--temperature', default=3, type=int,
                        help='temperature to smooth the logits')
    parser.add_argument('--alpha', default=0.1, type=float,
                        help='weight of kd loss')
    parser.add_argument('--beta', default=1e-6, type=float,
                        help='weight of feature loss')
    # NOTE(review): the '<Required>' wording contradicts required=False for
    # --blockprobs / --headprobs; their meaning is not visible in this file,
    # so the help text is left as it was.
    parser.add_argument('--blockprobs', type=float, nargs='+', help='<Required> Set flag', required=False)
    parser.add_argument('--score', default='random', type=str,
                        help='random,snip')
    parser.add_argument('--eta', default=0.05, type=float,
                        help='step size of the ensemble-weight update')
    parser.add_argument('--headprobs', type=float, nargs='+', help='<Required> Set flag', required=False)
    parser.add_argument('--update_w_step', default=10, type=int,
                        help='update the ensemble weights every N epochs '
                             '(default: 10)')
    args = parser.parse_args()
    return args

def main():
    """Parse the CLI, configure logging and dispatch to training or evaluation.

    A timestamped output directory ``<save_folder>/<arch>/<YYYYmmdd_HHMMSS>``
    is created; the log file ``log_<cmd>.txt`` is written there and mirrored to
    the console.
    """
    args = parse_args()

    save_path = args.save_path = os.path.join(
        args.save_folder, args.arch, time.strftime("%Y%m%d_%H%M%S"))
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_path, exist_ok=True)
    args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd))
    handlers = [logging.FileHandler(args.logger_file, mode='w'),
                logging.StreamHandler()]
    logging.basicConfig(level=logging.INFO,
                        datefmt='%m-%d-%y %H:%M',
                        format='%(asctime)s:%(message)s',
                        handlers=handlers)
    logging.info(os.path.basename(__file__))
    logging.info(args)
    if args.cmd == 'train':
        logging.info('start training {}'.format(args.arch))
        run_training(args)

    elif args.cmd == 'test':
        logging.info('start evaluating {} with checkpoints from {}'.format(
            args.arch, args.resume))
        # No training run precedes a stand-alone evaluation, so there are no
        # learned ensemble weights in this process. Fall back to the uniform
        # weighting that training itself starts from (four heads, 1/4 each);
        # previously this branch crashed on an unassigned variable.
        ensemble_weights = [0.25] * 4
        run_test(args, ensemble_weights)


def run_test(args, ensemble_weights):
    """Evaluate a (checkpointed) multi-head model on the test split.

    Args:
        args: parsed command-line namespace; reads ``dataset``, ``arch``,
            ``device``, ``resume``, ``data_dir``, ``valid_batch_size`` and
            ``workers``. ``args.start_epoch`` is overwritten when a checkpoint
            is loaded.
        ensemble_weights: indexable with one weight per head, forwarded to
            ``validate`` for the weighted ensemble.

    Raises:
        NotImplementedError: for any dataset other than ``cifar100``.
        SystemExit: when ``args.resume`` is set but is not an existing file.
    """
    if args.dataset == 'cifar100':
        model = models.__dict__[args.arch](num_classes=100)
    else:
        raise NotImplementedError

    # NOTE(review): run_training prunes the model (prune_model_18) before
    # training, whereas the model here is built unpruned; the checkpoint's
    # tensor shapes may therefore not match -- confirm.
    model = torch.nn.DataParallel(model).to(args.device)
    # load checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            logging.info('=> no checkpoint found at `{}`'.format(args.resume))
            # Non-zero status so that callers/scripts can detect the failure.
            raise SystemExit(1)
        logging.info("=> loading checkpoint `{}`".format(args.resume))
        # map_location lets a checkpoint written on GPU be evaluated on CPU.
        checkpoint = torch.load(args.resume, map_location=args.device)
        args.start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['state_dict'])
        logging.info('=> loaded checkpoint `{}` (epoch: {})'.format(
            args.resume, checkpoint['epoch']))

    cudnn.benchmark = True

    # load datasets
    if args.dataset == 'cifar100':
        # Integer division: a true division would hand a float batch size to
        # the loader. NOTE(review): the halving itself is preserved from the
        # original; its motivation is not visible here.
        test_loader = prepare_cifar100_test_dataset(data_dir=args.data_dir,
                                                    batch_size=max(1, args.valid_batch_size // 2),
                                                    num_workers=args.workers)
    elif args.dataset == 'cifar10':
        # NOTE(review): currently unreachable, because model construction
        # above rejects every dataset except cifar100.
        test_loader = prepare_cifar10_test_dataset(data_dir=args.data_dir,
                                                   batch_size=args.valid_batch_size,
                                                   num_workers=args.workers)
    else:
        raise NotImplementedError
    criterion = nn.CrossEntropyLoss().to(args.device)
    validate(args, test_loader, model, criterion, ensemble_weights)

def _log_model_complexity(model):
    """Log parameter count and MACs/FLOPs of `model` for one 1x3x32x32 input.

    The model is moved to CPU for the measurement (``model.cpu()``).
    """
    macs, params = profile(model.cpu(), inputs=(torch.randn(1, 3, 32, 32),))
    logging.info("Number of Parameters: %.1fM" % (params / 1e6))
    logging.info("Number of MACS: %.1fM FLOPS: %.1fM" % (macs / 1e6, 2 * macs / 1e6))


def run_training(args):
    """Prune a four-head model and train the heads jointly with mutual distillation.

    Per batch the loss is::

        (1 - alpha) * sum_k w[k] * CE(head_k)
        + alpha     * sum_t w[t] * sum_{s != t} KD(student=head_s, teacher=head_t)

    where ``w`` holds one ensemble weight per head. ``w`` starts uniform and is
    refreshed by ``update_w`` every ``args.update_w_step`` epochs (never after
    the last epoch). After every epoch the model is validated and two
    checkpoints are written into ``args.save_path``:
    ``checkpoint_current_epoch.pth.tar`` always, and ``checkpoint_best.pth.tar``
    when the uniform-ensemble accuracy improved.

    Args:
        args: parsed command-line namespace. ``args.start_epoch`` is
            overwritten when resuming from a checkpoint.

    Returns:
        The final ensemble weights (tensor with one entry per head).

    Raises:
        NotImplementedError: for datasets other than ``cifar100`` or
            architectures other than ``resnet18_4head`` / ``resnet50_4head``.
    """
    num_heads = 4  # the loss and the log line below assume four heads
    w = torch.ones(num_heads, requires_grad=False) / num_heads

    if args.dataset == 'cifar100':
        train_loader = prepare_cifar100_train_dataset(data_dir=args.data_dir, batch_size=args.batch_size,
                                                      num_workers=args.workers)
        test_loader = prepare_cifar100_test_dataset(data_dir=args.data_dir, batch_size=args.valid_batch_size,
                                                    num_workers=args.workers)
        model = models.__dict__[args.arch](num_classes=100)
    else:
        raise NotImplementedError

    # Complexity is reported before and after pruning so both can be compared.
    _log_model_complexity(model)
    logging.info(model)
    if args.arch == 'resnet18_4head' or args.arch == 'resnet50_4head':
        model = prune_model_18(model, args, train_loader)
    else:
        raise NotImplementedError
    best_prec1 = 0
    best_ensemble = 0
    logging.info(model)
    _log_model_complexity(model)

    model = torch.nn.DataParallel(model).to(args.device)
    torch.cuda.empty_cache()

    if args.resume:
        if os.path.isfile(args.resume):
            logging.info("=> loading checkpoint `{}`".format(args.resume))
            # map_location lets a GPU checkpoint be resumed on another device.
            checkpoint = torch.load(args.resume, map_location=args.device)
            # Checkpoints are written after an epoch has finished, so training
            # continues with the following epoch instead of repeating it.
            args.start_epoch = checkpoint['epoch'] + 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Older checkpoints carry no ensemble weights; keep uniform then.
            saved_w = checkpoint.get('ensemble_weights')
            if saved_w is not None:
                w = saved_w.detach().cpu()
            logging.info('=> loaded checkpoint `{}` (epoch: {})'.format(
                args.resume, checkpoint['epoch']))
        else:
            # Best effort: a missing checkpoint is logged and training starts
            # from the freshly pruned model.
            logging.info('=> no checkpoint found at `{}`'.format(args.resume))

    cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss().to(args.device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    end = time.time()
    model.train()
    temperature_sq = args.temperature ** 2
    for current_epoch in range(args.start_epoch, args.epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        total_losses = AverageMeter()
        head_top1 = [AverageMeter() for _ in range(num_heads)]

        adjust_learning_rate(args, optimizer, current_epoch)

        for i, (input, target) in enumerate(train_loader):
            if args.debug and i > 200:
                break
            torch.cuda.empty_cache()
            data_time.update(time.time() - end)

            target = target.squeeze().long().to(args.device)
            input = input.to(args.device)
            # Unpacking enforces that the model really returns four heads.
            out1, out2, out3, out4 = model(input)
            outputs = (out1, out2, out3, out4)

            ce_losses = [criterion(out, target) for out in outputs]
            # Temperature-softened distributions used as distillation targets.
            # NOTE(review): they are not detached, so gradients also flow into
            # the teaching head (update_w does detach its targets) -- confirm
            # that this is intended.
            soft_targets = [torch.softmax(out / args.temperature, dim=1) for out in outputs]

            # For every teacher head t: summed KD loss of all other heads
            # (students, in ascending order) against t's soft targets.
            kd_by_teacher = [
                sum(kd_loss_function(outputs[s], soft_targets[t], args) * temperature_sq
                    for s in range(num_heads) if s != t)
                for t in range(num_heads)
            ]

            supervised = sum(w[k] * ce_losses[k] for k in range(num_heads))
            distillation = sum(w[t] * kd_by_teacher[t] for t in range(num_heads))
            total_loss = (1 - args.alpha) * supervised + args.alpha * distillation

            total_losses.update(total_loss.item(), input.size(0))

            for meter, out in zip(head_top1, outputs):
                prec1 = accuracy(out.data, target, topk=(1,))
                meter.update(prec1[0], input.size(0))

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            # Drop graph references early to keep peak memory down.
            del total_loss, supervised, distillation, kd_by_teacher
            del ce_losses, soft_targets, outputs, out1, out2, out3, out4

            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                logging.info("Epoch: [{0}]\t"
                            "Iter: [{1}]\t"
                            "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                            "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                            "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
                            "Prec@1 {middle1_top1.val:.3f} ({middle1_top1.avg:.3f})\t"
                            "Prec@2 {middle2_top1.val:.3f} ({middle2_top1.avg:.3f})\t"
                            "Prec@3 {middle3_top1.val:.3f} ({middle3_top1.avg:.3f})\t"
                            "Prec@4 {middle4_top1.val:.3f} ({middle4_top1.avg:.3f})\t".format(
                                current_epoch,
                                i,
                                batch_time=batch_time,
                                data_time=data_time,
                                loss=total_losses,
                                middle1_top1=head_top1[0],
                                middle2_top1=head_top1[1],
                                middle3_top1=head_top1[2],
                                middle4_top1=head_top1[3],)
                )
            torch.cuda.empty_cache()

        # Refresh the ensemble weights periodically, but not after the final
        # epoch. The test loader supplies the single batch used by update_w.
        if (current_epoch + 1) % args.update_w_step == 0 and current_epoch < args.epoch - 1:
            w = update_w(args, test_loader, model, criterion, w, optimizer)

        top1_ensemble_ori, top1_ensemble = validate(args, test_loader, model, criterion, w)
        best_ensemble = max(best_ensemble, top1_ensemble)
        is_best_prec1 = top1_ensemble_ori > best_prec1
        best_prec1 = max(top1_ensemble_ori, best_prec1)
        logging.info('best ensemble ori: {:.2f}, best ensemble {:.2f}'.format(best_prec1, best_ensemble))

        state = {
            'epoch': current_epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            # Stored so that a resumed run continues with the learned weights.
            'ensemble_weights': w,
        }
        if is_best_prec1:
            torch.save(state, os.path.join(args.save_path, 'checkpoint_best.pth.tar'))
        torch.save(state, os.path.join(args.save_path, 'checkpoint_current_epoch.pth.tar'))
        torch.cuda.empty_cache()
    return w



def _collect_grads(model):
    """Return detached copies of every populated parameter gradient of `model`."""
    return [p.grad.detach().clone() for p in model.parameters() if p.grad is not None]


def _grad_dot(grads_a, grads_b):
    """Inner product of two gradient lists aligned parameter by parameter."""
    return torch.as_tensor([(u * v).sum() for u, v in zip(grads_a, grads_b)]).sum()


def _forward_heads(model, input):
    """Run `model` and return its four head outputs as a tuple."""
    out1, out2, out3, out4 = model(input)
    return out1, out2, out3, out4


def update_w(args, train_loader, model, criterion, a, optimizer):
    """Take one multiplicative update step on the ensemble weights.

    Only the first batch of ``train_loader`` is used. It is better to isolate
    these data as a validation set, which is dismissed here for simplicity.
    On that batch:

    1. For each head k, the gradient ``g_k`` of
       ``(1 - alpha) * CE(head_k) + alpha * 4 * T^2 * KD(mean of heads, soft(head_k))``
       with respect to the model parameters is computed (soft targets
       detached).
    2. The gradient ``g_ref`` of ``sum_k a[k] * CE(head_k)`` is computed with
       respect to the model parameters, together with its gradient with
       respect to ``a`` (which is ``CE(head_k)`` for entry k).
    3. ``a[k] <- a[k] * exp(eta * (lr * <g_k, g_ref> - CE(head_k)))`` and the
       result is normalized to sum to one.

    The model parameters are not modified (no optimizer step is taken); all
    parameter gradients are cleared before returning.

    Args:
        args: namespace providing ``device``, ``temperature``, ``alpha`` and
            ``eta``.
        train_loader: iterable of ``(input, target)`` batches.
        model: module returning four head outputs.
        criterion: supervised loss, called as ``criterion(logits, target)``.
        a: current ensemble weights (tensor with four entries); not modified.
        optimizer: used for ``zero_grad`` and to read the current learning
            rate from its first parameter group.

    Returns:
        The updated, normalized weight tensor. If ``train_loader`` yields no
        batch, ``a`` is returned unchanged.
    """
    torch.cuda.empty_cache()
    optimizer.zero_grad()

    first_batch = next(iter(train_loader), None)
    if first_batch is None:
        # Nothing to estimate the update from: keep the current weights.
        logging.warning('update_w: data loader is empty; ensemble weights left unchanged')
        return a
    input, target = first_batch
    target = target.squeeze().long().to(args.device)
    input = input.to(args.device)
    temperature_sq = args.temperature ** 2

    # Step 1: one forward/backward pass per head objective.
    head_grads = []
    for head in range(4):
        torch.cuda.empty_cache()
        optimizer.zero_grad()
        outputs = _forward_heads(model, input)
        output_average = (outputs[0] + outputs[1] + outputs[2] + outputs[3]) / 4
        soft_target = torch.softmax(outputs[head] / args.temperature, dim=1).detach()
        head_loss = ((1 - args.alpha) * criterion(outputs[head], target)
                     + args.alpha * 4 * kd_loss_function(output_average, soft_target, args) * temperature_sq)
        head_loss.backward()
        head_grads.append(_collect_grads(model))
        del head_loss, soft_target, output_average, outputs

    # Step 2: reference gradient of the weighted supervised loss. `copy_a` is
    # a leaf copy of the weights so that autograd also yields d(loss)/d(a).
    torch.cuda.empty_cache()
    optimizer.zero_grad()
    copy_a = a.detach().clone().requires_grad_(True)
    outputs = _forward_heads(model, input)
    loss_L2_weight = (copy_a[0] * criterion(outputs[0], target)
                      + copy_a[1] * criterion(outputs[1], target)
                      + copy_a[2] * criterion(outputs[2], target)
                      + copy_a[3] * criterion(outputs[3], target))
    loss_L2_weight.backward()
    ref_grads = _collect_grads(model)
    loss_grad_wrt_a = copy_a.grad.detach()
    del loss_L2_weight, outputs

    # Step 3: multiplicative update on a copy, then renormalize.
    lr = optimizer.param_groups[0]['lr']
    updated = a.detach().clone()
    for head, grads in enumerate(head_grads):
        alignment = _grad_dot(grads, ref_grads)
        updated[head] = a[head] * torch.exp(args.eta * (lr * alignment - loss_grad_wrt_a[head]))
    updated = updated / updated.sum()

    torch.cuda.empty_cache()
    optimizer.zero_grad()

    logging.info('updated weights ---------------------------------------')
    logging.info(updated)

    return updated

def validate(args, test_loader, model, criterion, updated_w):
    """Evaluate the four heads and two ensembles of `model` on `test_loader`.

    Two ensembles of the head logits are scored: one weighted by `updated_w`
    and one plain average ("ori"). The model is switched to eval mode for the
    pass and back to train mode before returning. With a non-zero
    ``args.debug`` the pass stops once the batch index exceeds 200.

    Returns:
        (average top-1 of the plain-average ensemble,
         average top-1 of the `updated_w`-weighted ensemble)
    """
    batch_time = AverageMeter()

    # Ensemble meters.
    losses_ensemble, losses_ensemble_ori = AverageMeter(), AverageMeter()
    top1_ensemble, top1_ensemble_ori = AverageMeter(), AverageMeter()
    # One loss meter and one accuracy meter per head.
    head_loss_meters = [AverageMeter() for _ in range(4)]
    head_top1_meters = [AverageMeter() for _ in range(4)]

    model.eval()
    tick = time.time()
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(test_loader):
            if args.debug and batch_idx > 200:
                break
            torch.cuda.empty_cache()
            labels = labels.squeeze().long().to(args.device)
            images = images.to(args.device)
            n_samples = images.size(0)

            head1, head2, head3, head4 = model(images)
            heads = (head1, head2, head3, head4)

            weighted_logits = (updated_w[0] * head1 + updated_w[1] * head2
                               + updated_w[2] * head3 + updated_w[3] * head4)
            average_logits = (head1 + head2 + head3 + head4) / 4

            losses_ensemble.update(criterion(weighted_logits, labels).item(), n_samples)
            losses_ensemble_ori.update(criterion(average_logits, labels).item(), n_samples)
            for meter, logits in zip(head_loss_meters, heads):
                meter.update(criterion(logits, labels).item(), n_samples)

            top1_ensemble_ori.update(
                accuracy(average_logits.data, labels, topk=(1,))[0], n_samples)
            top1_ensemble.update(
                accuracy(weighted_logits.data, labels, topk=(1,))[0], n_samples)
            for meter, logits in zip(head_top1_meters, heads):
                meter.update(accuracy(logits.data, labels, topk=(1,))[0], n_samples)

            now = time.time()
            batch_time.update(now - tick)
            tick = time.time()

    summary = ("Loss_ensemble {ensemble_loss.avg:.3f}\t"
               "Prec_ensemble@1 {top1_ensemble.avg:.3f}\t"
               "Loss_ensemble_ori {ensemble_loss_ori.avg:.3f}\t"
               "Prec_ensemble_ori@1 {top1_ensemble_ori.avg:.3f}\t"
               "Middle1@1 {middle1_top1.avg:.3f}\t"
               "Middle2@1 {middle2_top1.avg:.3f}\t"
               "Middle3@1 {middle3_top1.avg:.3f}\t"
               "Middle4@1 {middle4_top1.avg:.3f}\t")
    logging.info(summary.format(
        ensemble_loss=losses_ensemble,
        top1_ensemble=top1_ensemble,
        ensemble_loss_ori=losses_ensemble_ori,
        top1_ensemble_ori=top1_ensemble_ori,
        middle1_top1=head_top1_meters[0],
        middle2_top1=head_top1_meters[1],
        middle3_top1=head_top1_meters[2],
        middle4_top1=head_top1_meters[3]))
    torch.cuda.empty_cache()
    model.train()
    return top1_ensemble_ori.avg, top1_ensemble.avg

def kd_loss_function(output, target_output, args):
    """Compute the knowledge-distillation (soft cross-entropy) loss.

    Args:
        output: student logits; they are divided by ``args.temperature`` here.
        target_output: teacher distribution, i.e. logits that have already
            been divided by the temperature and passed through softmax.
        args: object providing ``temperature``.

    Returns:
        Scalar tensor: batch mean of ``-sum(target * log_softmax(output / T))``
        taken over dimension 1.
    """
    student_log_probs = torch.log_softmax(output / args.temperature, dim=1)
    per_sample = torch.sum(student_log_probs * target_output, dim=1)
    return -torch.mean(per_sample)

def feature_loss_function(fea, target_fea):
    """Sum of squared differences over positions where either input is positive.

    Positions where both `fea` and `target_fea` are <= 0 contribute nothing.
    """
    # 1.0 where at least one of the two tensors is positive, else 0.0.
    active = ((fea > 0) | (target_fea > 0)).float()
    masked_sq_diff = (fea - target_fea) ** 2 * active
    return masked_sq_diff.abs().sum()

class AverageMeter(object):
    """Keeps the most recent value and the running weighted mean of all values."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as seen `n` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def adjust_learning_rate(args, optimizer, epoch):
    """Set the step-schedule learning rate for `epoch` on every param group.

    The base rate ``args.lr`` is multiplied by ``args.step_ratio`` once for
    each milestone (epochs 75, 130 and 180) that has been reached. With
    ``args.warm_up`` set, epochs below 1 use a fixed rate of 0.01 instead.
    """
    milestones = (75, 130, 180)

    if args.warm_up and epoch < 1:
        new_lr = 0.01
    else:
        reached = sum(1 for boundary in milestones if epoch >= boundary)
        # Before the first milestone the base rate is used untouched.
        new_lr = args.lr * (args.step_ratio ** reached) if reached else args.lr

    logging.info('Epoch [{}] learning rate = {}'.format(epoch, new_lr))

    for group in optimizer.param_groups:
        group['lr'] = new_lr

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for every k in `topk`.

    Args:
        output: score tensor; the top-k is taken along dimension 1.
        target: tensor of class indices, one per row of `output`.
        topk: iterable of k values to evaluate.

    Returns:
        List with one 0-dim tensor per entry of `topk`, each holding the
        percentage of rows whose target is among the k highest scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: `correct` derives from a transposed tensor
            # and is not guaranteed to be contiguous, which makes view(-1)
            # fail as soon as maxk > 1.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul(100.0 / batch_size))

    return res

def save_checkpoint(state, is_best, filename):
    """Serialize `state` to `filename`; duplicate it as the best model if asked.

    When `is_best` is true, the written file is additionally copied to
    ``model_best.path.tar`` in the same directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    best_copy = os.path.join(os.path.dirname(filename), 'model_best.path.tar')
    shutil.copyfile(filename, best_copy)

# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()


